{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "1_getting_started.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hyyN-2qyK_T2",
        "colab_type": "text"
      },
      "source": [
        "# Stable Baselines Tutorial - Getting Started\n",
        "\n",
        "Github repo: https://github.com/araffin/rl-tutorial-jnrr19\n",
        "\n",
        "Stable-Baselines: https://github.com/hill-a/stable-baselines\n",
        "\n",
        "Documentation: https://stable-baselines.readthedocs.io/en/master/\n",
        "\n",
        "RL Baselines zoo: https://github.com/araffin/rl-baselines-zoo\n",
        "\n",
        "Medium article: [https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-df87c4b2fc82](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-df87c4b2fc82)\n",
        "\n",
        "[RL Baselines Zoo](https://github.com/araffin/rl-baselines-zoo) is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines.\n",
        "\n",
        "It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos.\n",
        "\n",
        "\n",
        "## Introduction\n",
        "\n",
        "In this notebook, you will learn the basics for using stable baselines library: how to create a RL model, train it and evaluate it. Because all algorithms share the same interface, we will see how simple it is to switch from one algorithm to another.\n",
        "\n",
        "\n",
        "## Install Dependencies and Stable Baselines Using Pip\n",
        "\n",
        "List of full dependencies can be found in the [README](https://github.com/hill-a/stable-baselines).\n",
        "\n",
        "```\n",
        "sudo apt-get update && sudo apt-get install cmake libopenmpi-dev zlib1g-dev\n",
        "```\n",
        "\n",
        "\n",
        "```\n",
        "pip install stable-baselines[mpi]\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gWskDE2c9WoN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Stable Baselines only supports tensorflow 1.x for now\n",
        "%tensorflow_version 1.x\n",
        "!apt-get install ffmpeg freeglut3-dev xvfb  # For visualization\n",
        "!pip install stable-baselines[mpi]==2.10.2"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U29X1-B-AIKE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import stable_baselines\n",
        "stable_baselines.__version__"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FtY8FhliLsGm",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gcX8hEcaUpR0",
        "colab_type": "text"
      },
      "source": [
        "Stable-Baselines works on environments that follow the [gym interface](https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html).\n",
        "You can find a list of available environment [here](https://gym.openai.com/envs/#classic_control).\n",
        "\n",
        "It is also recommended to check the [source code](https://github.com/openai/gym) to learn more about the observation and action space of each env, as gym does not have a proper documentation.\n",
        "Not all algorithms can work with all action spaces, you can find more in this [recap table](https://stable-baselines.readthedocs.io/en/master/guide/algos.html)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BIedd7Pz9sOs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import gym\n",
        "import numpy as np"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ae32CtgzTG3R",
        "colab_type": "text"
      },
      "source": [
        "The first thing you need to import is the RL model, check the documentation to know what you can use on which problem"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R7tKaBFrTR0a",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from stable_baselines import PPO2"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-0_8OQbOTTNT",
        "colab_type": "text"
      },
      "source": [
        "The next thing you need to import is the policy class that will be used to create the networks (for the policy/value functions).\n",
        "This step is optional as you can directly use strings in the constructor: \n",
        "\n",
        "```PPO2('MlpPolicy', env)``` instead of ```PPO2(MlpPolicy, env)```\n",
        "\n",
        "Note that some algorithms like `SAC` have their own `MlpPolicy` (different from `stable_baselines.common.policies.MlPolicy`), that's why using string for the policy is the recommened option."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ROUJr675TT01",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from stable_baselines.common.policies import MlpPolicy"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RapkYvTXL7Cd",
        "colab_type": "text"
      },
      "source": [
        "## Create the Gym env and instantiate the agent\n",
        "\n",
        "For this example, we will use CartPole environment, a classic control problem.\n",
        "\n",
        "\"A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. \"\n",
        "\n",
        "Cartpole environment: [https://gym.openai.com/envs/CartPole-v1/](https://gym.openai.com/envs/CartPole-v1/)\n",
        "\n",
        "![Cartpole](https://cdn-images-1.medium.com/max/1143/1*h4WTQNVIsvMXJTCpXm_TAw.gif)\n",
        "\n",
        "Note: vectorized environments allow to easily multiprocess training. In this example, we are using only one process, hence the DummyVecEnv.\n",
        "\n",
        "We chose the MlpPolicy because input of CartPole is a feature vector, not images.\n",
        "\n",
        "The type of action to use (discrete/continuous) will be automatically deduced from the environment action space\n",
        "\n",
        "\n",
        "Here we are using the [Proximal Policy Optimization](https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html) algorithm (PPO2 is the version optimized for GPU), which is an Actor-Critic method: it uses a value function to improve the policy gradient descent (by reducing the variance).\n",
        "\n",
        "It combines ideas from [A2C](https://stable-baselines.readthedocs.io/en/master/modules/a2c.html) (having multiple workers and using an entropy bonus for exploration) and [TRPO](https://stable-baselines.readthedocs.io/en/master/modules/trpo.html) (it uses a trust region to improve stability and avoid catastrophic drops in performance).\n",
        "\n",
        "PPO is an on-policy algorithm, which means that the trajectories used to update the networks must be collected using the latest policy.\n",
        "It is usually less sample efficient than off-policy alorithms like [DQN](https://stable-baselines.readthedocs.io/en/master/modules/dqn.html), [SAC](https://stable-baselines.readthedocs.io/en/master/modules/sac.html) or [TD3](https://stable-baselines.readthedocs.io/en/master/modules/td3.html), but is much faster regarding wall-clock time.\n"
      ]
    },
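    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the trust-region idea above more concrete, here is a minimal NumPy sketch of PPO's clipped surrogate objective. It is an illustration under stated assumptions, not the Stable-Baselines implementation: the function name `clipped_surrogate` and the toy inputs are hypothetical, and `clip_range=0.2` is simply a commonly used value.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def clipped_surrogate(ratio, advantage, clip_range=0.2):\n",
        "    # ratio = pi_new(a|s) / pi_old(a|s), one entry per sampled transition\n",
        "    unclipped = ratio * advantage\n",
        "    # Keep the probability ratio inside [1 - clip_range, 1 + clip_range]\n",
        "    clipped = np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantage\n",
        "    # The element-wise minimum gives a pessimistic estimate of the improvement\n",
        "    return np.mean(np.minimum(unclipped, clipped))\n",
        "\n",
        "\n",
        "# Toy inputs (hypothetical values, for illustration only)\n",
        "ratio = np.array([0.5, 1.0, 1.5])\n",
        "advantage = np.array([1.0, 1.0, 1.0])\n",
        "objective = clipped_surrogate(ratio, advantage)\n",
        "```\n",
        "\n",
        "For a positive advantage, a probability ratio above `1 + clip_range` earns no extra objective, so the update has little incentive to push the new policy far away from the one that collected the data."
      ]
    },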
    {
      "cell_type": "code",
      "metadata": {
        "id": "pUWGZp3i9wyf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "env = gym.make('CartPole-v1')\n",
        "# vectorized environments allow to easily multiprocess training\n",
        "# we demonstrate its usefulness in the next examples\n",
        "# env = DummyVecEnv([lambda: env])\n",
        "\n",
        "model = PPO2(MlpPolicy, env, verbose=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4efFdrQ7MBvl",
        "colab_type": "text"
      },
      "source": [
        "We create a helper function to evaluate the agent:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "63M8mSKR-6Zt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def evaluate(model, num_episodes=100):\n",
        "    \"\"\"\n",
        "    Evaluate a RL agent\n",
        "    :param model: (BaseRLModel object) the RL Agent\n",
        "    :param num_episodes: (int) number of episodes to evaluate it\n",
        "    :return: (float) Mean reward for the last num_episodes\n",
        "    \"\"\"\n",
        "    # This function will only work for a single Environment\n",
        "    env = model.get_env()\n",
        "    all_episode_rewards = []\n",
        "    for i in range(num_episodes):\n",
        "        episode_rewards = []\n",
        "        done = False\n",
        "        obs = env.reset()\n",
        "        while not done:\n",
        "            # _states are only useful when using LSTM policies\n",
        "            action, _states = model.predict(obs)\n",
        "            # here, action, rewards and dones are arrays\n",
        "            # because we are using vectorized env\n",
        "            obs, reward, done, info = env.step(action)\n",
        "            episode_rewards.append(reward)\n",
        "\n",
        "        all_episode_rewards.append(sum(episode_rewards))\n",
        "\n",
        "    mean_episode_reward = np.mean(all_episode_rewards)\n",
        "    print(\"Mean reward:\", mean_episode_reward, \"Num episodes:\", num_episodes)\n",
        "\n",
        "    return mean_episode_reward"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6hkyafs--gJz",
        "colab_type": "text"
      },
      "source": [
        "In fact, Stable-Baselines already provides you with that helper:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s6ZNldIR-fce",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from stable_baselines.common.evaluation import evaluate_policy"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zjEVOIY8NVeK",
        "colab_type": "text"
      },
      "source": [
        "Let's evaluate the un-trained agent, this should be a random agent."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xDHLMA6NFk95",
        "colab_type": "code",
        "outputId": "da86c8ac-2642-4fca-d026-da93d990130a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Use a separate environement for evaluation\n",
        "eval_env = gym.make('CartPole-v1')\n",
        "\n",
        "# Random Agent, before training\n",
        "mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=100)\n",
        "\n",
        "print(f\"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}\")"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "mean_reward:9.68 +/- 0.68\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r5UoXTZPNdFE",
        "colab_type": "text"
      },
      "source": [
        "## Train the agent and evaluate it"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e4cfSXIB-pTF",
        "colab_type": "code",
        "outputId": "14e71abd-5513-49e6-9641-66227be3f165",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Train the agent for 10000 steps\n",
        "model.learn(total_timesteps=10000)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<stable_baselines.ppo2.ppo2.PPO2 at 0x7f95a045fac8>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ygl_gVmV_QP7",
        "colab_type": "code",
        "outputId": "205c23a7-acc6-47ab-b779-813da1fba49e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Evaluate the trained agent\n",
        "mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=100)\n",
        "\n",
        "print(f\"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}\")"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "mean_reward:145.87 +/- 77.90\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A00W6yY3NkHG",
        "colab_type": "text"
      },
      "source": [
        "Apparently the training went well, the mean reward increased a lot ! "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xVm9QPNVwKXN",
        "colab_type": "text"
      },
      "source": [
        "### Prepare video recording"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MPyfQxD5z26J",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Set up fake display; otherwise rendering will fail\n",
        "import os\n",
        "os.system(\"Xvfb :1 -screen 0 1024x768x24 &\")\n",
        "os.environ['DISPLAY'] = ':1'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SLzXxO8VMD6N",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import base64\n",
        "from pathlib import Path\n",
        "\n",
        "from IPython import display as ipythondisplay\n",
        "\n",
        "def show_videos(video_path='', prefix=''):\n",
        "  \"\"\"\n",
        "  Taken from https://github.com/eleurent/highway-env\n",
        "\n",
        "  :param video_path: (str) Path to the folder containing videos\n",
        "  :param prefix: (str) Filter the video, showing only the only starting with this prefix\n",
        "  \"\"\"\n",
        "  html = []\n",
        "  for mp4 in Path(video_path).glob(\"{}*.mp4\".format(prefix)):\n",
        "      video_b64 = base64.b64encode(mp4.read_bytes())\n",
        "      html.append('''<video alt=\"{}\" autoplay \n",
        "                    loop controls style=\"height: 400px;\">\n",
        "                    <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
        "                </video>'''.format(mp4, video_b64.decode('ascii')))\n",
        "  ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LTRNUfulOGaF",
        "colab_type": "text"
      },
      "source": [
        "We will record a video using the [VecVideoRecorder](https://stable-baselines.readthedocs.io/en/master/guide/vec_envs.html#vecvideorecorder) wrapper, you will learn about those wrapper in the next notebook."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Trag9dQpOIhx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from stable_baselines.common.vec_env import VecVideoRecorder, DummyVecEnv\n",
        "\n",
        "def record_video(env_id, model, video_length=500, prefix='', video_folder='videos/'):\n",
        "  \"\"\"\n",
        "  :param env_id: (str)\n",
        "  :param model: (RL model)\n",
        "  :param video_length: (int)\n",
        "  :param prefix: (str)\n",
        "  :param video_folder: (str)\n",
        "  \"\"\"\n",
        "  eval_env = DummyVecEnv([lambda: gym.make('CartPole-v1')])\n",
        "  # Start the video at step=0 and record 500 steps\n",
        "  eval_env = VecVideoRecorder(eval_env, video_folder=video_folder,\n",
        "                              record_video_trigger=lambda step: step == 0, video_length=video_length,\n",
        "                              name_prefix=prefix)\n",
        "\n",
        "  obs = eval_env.reset()\n",
        "  for _ in range(video_length):\n",
        "    action, _ = model.predict(obs)\n",
        "    obs, _, _, _ = eval_env.step(action)\n",
        "\n",
        "  # Close the video recorder\n",
        "  eval_env.close()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KOObbeu5MMlR",
        "colab_type": "text"
      },
      "source": [
        "### Visualize trained agent\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iATu7AiyMQW2",
        "colab_type": "code",
        "outputId": "45feec21-11b5-4aea-c580-2993fc3d2b81",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "record_video('CartPole-v1', model, video_length=500, prefix='ppo2-cartpole')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Saving video to  /content/videos/ppo2-cartpole-step-0-to-step-500.mp4\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-n4i-fW3NojZ",
        "colab_type": "code",
        "outputId": "345d869b-3bc6-48eb-86d2-949fadf5d4fe",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 420
        }
      },
      "source": [
        "show_videos('videos', prefix='ppo2')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              ""
            ],
            "text/plain": [
              ""
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Y8zg4V566qD",
        "colab_type": "text"
      },
      "source": [
        "## Bonus: Train a RL Model in One Line\n",
        "\n",
        "The policy class to use will be inferred and the environment will be automatically created. This works because both are [registered](https://stable-baselines.readthedocs.io/en/master/guide/quickstart.html)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iaOPfOrwWEP4",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = PPO2('MlpPolicy', \"CartPole-v1\", verbose=1).learn(1000)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WBf84jU05mgi",
        "colab_type": "text"
      },
      "source": [
        "## Train a DQN agent\n",
        "\n",
        "In the previous example, we have used PPO, which one of the many algorithms provided by stable-baselines.\n",
        "\n",
        "In the next example, we are going train a [Deep Q-Network agent (DQN)](https://stable-baselines.readthedocs.io/en/master/modules/dqn.html), and try to see possible improvements provided by its extensions (Double-DQN, Dueling-DQN, Prioritized Experience Replay).\n",
        "\n",
        "The essential point of this section is to show you how simple it is to tweak hyperparameters.\n",
        "\n",
        "The main advantage of stable-baselines is that it provides a common interface to use the algorithms, so the code will be quite similar.\n",
        "\n",
        "\n",
        "DQN paper: https://arxiv.org/abs/1312.5602\n",
        "\n",
        "Dueling DQN: https://arxiv.org/abs/1511.06581\n",
        "\n",
        "Double-Q Learning: https://arxiv.org/abs/1509.06461\n",
        "\n",
        "Prioritized Experience Replay: https://arxiv.org/abs/1511.05952"
      ]
    },
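    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough illustration of what the Double-DQN extension changes, the sketch below compares the bootstrap target of vanilla DQN with the Double-DQN target using plain NumPy. It is a simplified sketch, not the Stable-Baselines code: the function names, the Q-value arrays and `gamma=0.99` are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def dqn_target(reward, done, next_q_target, gamma=0.99):\n",
        "    # Vanilla DQN: the target network both selects and evaluates the next action\n",
        "    return reward + gamma * (1.0 - done) * next_q_target.max(axis=1)\n",
        "\n",
        "\n",
        "def double_dqn_target(reward, done, next_q_online, next_q_target, gamma=0.99):\n",
        "    # Double DQN: the online network selects the next action...\n",
        "    best_action = next_q_online.argmax(axis=1)\n",
        "    # ...and the target network evaluates that action\n",
        "    evaluated = next_q_target[np.arange(len(best_action)), best_action]\n",
        "    return reward + gamma * (1.0 - done) * evaluated\n",
        "\n",
        "\n",
        "# Hypothetical batch of two transitions with two possible actions\n",
        "reward = np.array([1.0, 1.0])\n",
        "done = np.array([0.0, 1.0])\n",
        "next_q_online = np.array([[1.0, 2.0], [0.5, 0.1]])\n",
        "next_q_target = np.array([[3.0, 1.5], [0.2, 0.4]])\n",
        "\n",
        "vanilla = dqn_target(reward, done, next_q_target)\n",
        "double = double_dqn_target(reward, done, next_q_online, next_q_target)\n",
        "```\n",
        "\n",
        "Decoupling action selection from action evaluation is meant to reduce the overestimation of Q-values that the `max` operator can introduce."
      ]
    },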
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kikNl-tT92Ae",
        "colab_type": "text"
      },
      "source": [
        "### Vanilla DQN: DQN without extensions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tdLdYABz8TJx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Same as before we instantiate the agent along with the environment\n",
        "from stable_baselines import DQN\n",
        "\n",
        "# Deactivate all the DQN extensions to have the original version\n",
        "# In practice, it is recommend to have them activated\n",
        "kwargs = {'double_q': False, 'prioritized_replay': False, 'policy_kwargs': dict(dueling=False)}\n",
        "\n",
        "# Note that the MlpPolicy of DQN is different from the one of PPO\n",
        "# but stable-baselines handles that automatically if you pass a string\n",
        "dqn_model = DQN('MlpPolicy', 'CartPole-v1', verbose=1, **kwargs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "E4A_IOZ_9fzW",
        "colab_type": "code",
        "outputId": "b682bbf9-5de9-40d2-f9ff-b031598c5dd2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Random Agent, before training\n",
        "mean_reward_before_train = evaluate(dqn_model, num_episodes=100)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mean reward: 9.29 Num episodes: 100\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4wbuvAKk9spH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Train the agent for 10000 steps\n",
        "dqn_model.learn(total_timesteps=10000, log_interval=10)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QDQTdpYv9xJN",
        "colab_type": "code",
        "outputId": "363f5ed3-5bb1-4ce2-e218-10018818ecc9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Evaluate the trained agent\n",
        "mean_reward = evaluate(dqn_model, num_episodes=100)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mean reward: 130.02 Num episodes: 100\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GFJvqsMl96l7",
        "colab_type": "text"
      },
      "source": [
        "### DQN + Prioritized Replay"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "roCSjGu69-lA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Activate only the prioritized replay\n",
        "kwargs = {'double_q': False, 'prioritized_replay': True, 'policy_kwargs': dict(dueling=False)}\n",
        "\n",
        "dqn_per_model = DQN('MlpPolicy', 'CartPole-v1', verbose=1, **kwargs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DLnLhos5-lRm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "dqn_per_model.learn(total_timesteps=10000, log_interval=10)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xQjF5S_g-mFN",
        "colab_type": "code",
        "outputId": "4ce14e23-8f2b-40c1-90e1-b333d36f6997",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Evaluate the trained agent\n",
        "mean_reward = evaluate(dqn_per_model, num_episodes=100)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mean reward: 110.18 Num episodes: 100\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Skny8MUN9_Ky",
        "colab_type": "text"
      },
      "source": [
        "### DQN + Prioritized Experience Replay + Double Q-Learning + Dueling"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fDCys-Vg-yYR",
        "colab_type": "code",
        "outputId": "a3bc7511-d247-47ca-952e-53e95a1597a6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Activate all extensions\n",
        "kwargs = {'double_q': True, 'prioritized_replay': True, 'policy_kwargs': dict(dueling=True)}\n",
        "\n",
        "dqn_full_model = DQN('MlpPolicy', 'CartPole-v1', verbose=1, **kwargs)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Creating environment from the given name, wrapped in a DummyVecEnv.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "koHGB-VN-81O",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "dqn_full_model.learn(total_timesteps=10000, log_interval=10)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6cHRULdp--nN",
        "colab_type": "code",
        "outputId": "5b2a3471-f6fe-4594-8f91-e4c62505a07a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "mean_reward = evaluate(dqn_per_model, num_episodes=100)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mean reward: 110.02 Num episodes: 100\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n4Q9dR3UC5Zb",
        "colab_type": "text"
      },
      "source": [
        "In this particular example, the extensions does not seem to give any improvement compared to the simple DQN version.\n",
        "They are several reasons for that:\n",
        "\n",
        "1. `CartPole-v1` is a pretty simple environment\n",
        "2. We trained DQN for very few timesteps, not enough to see any difference\n",
        "3. The default hyperparameters for DQN are tuned for atari games, where the number of training timesteps is much larger (10^6) and input observations are images\n",
        "4. We have only compared one random seed per experiment"
      ]
    },
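    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Regarding the last point, one way to get a more reliable comparison is to repeat each experiment with several random seeds and report the mean and spread of the results. The helper below is a minimal sketch of that aggregation step; `summarize_runs` is a hypothetical name and the numbers are placeholders, not measured results.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def summarize_runs(mean_rewards):\n",
        "    # mean_rewards: one evaluation result per independent training run (seed)\n",
        "    rewards = np.asarray(mean_rewards, dtype=float)\n",
        "    # ddof=1 gives the sample standard deviation across runs\n",
        "    std = rewards.std(ddof=1) if len(rewards) > 1 else 0.0\n",
        "    return rewards.mean(), std\n",
        "\n",
        "\n",
        "# Placeholder values for illustration only, not measured results\n",
        "mean, std = summarize_runs([100.0, 120.0, 140.0])\n",
        "print(f'mean over seeds: {mean:.2f} +/- {std:.2f}')\n",
        "```\n",
        "\n",
        "In practice, each entry would come from training a fresh model with a different seed and evaluating it as done above."
      ]
    },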
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FrI6f5fWnzp-",
        "colab_type": "text"
      },
      "source": [
        "## Conclusion\n",
        "\n",
        "In this notebook we have seen:\n",
        "- how to define and train a RL model using stable baselines, it takes only one line of code ;)\n",
        "- how to use different RL algorithms and change some hyperparameters"
      ]
    }
  ]
}
